"""

python main_Het.py with algorithm_mode='hetdps' env_name=\'pettingzoo:pz-mpe-large-spread-v1\' time_limit=50 parallel_envs=32 experiment_label='NeurIPS' 
python main_Het.py with algorithm_mode='hetdps' env_name=\'pettingzoo:pz-mpe-large-spread-v2\' time_limit=50 parallel_envs=32 experiment_label='NeurIPS' 
python main_Het.py with algorithm_mode='hetdps' env_name=\'pettingzoo:pz-mpe-large-spread-v3\' time_limit=50 parallel_envs=32 experiment_label='NeurIPS' 
python main_Het.py with algorithm_mode='hetdps' env_name=\'pettingzoo:pz-mpe-large-spread-v4\' time_limit=50 parallel_envs=32 experiment_label='NeurIPS' 

Environment directory path: /home/MarlResearch/anaconda3/envs/IHet/lib/python3.8/site-packages/pettingzoo/mpe

"""

import numpy as np
import torch
import glob
import logging
import os
import shutil
import time
from collections import deque
from functools import partial
from os import path
from pathlib import Path
from collections import defaultdict
import json

import pickle
from cpprb import ReplayBuffer, create_before_add_func, create_env_dict
import gym
from sacred import Experiment
from sacred.observers import (
    FileStorageObserver,
    MongoObserver,  # noqa
    QueuedMongoObserver,
    QueueObserver,
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
from torch.utils.tensorboard import SummaryWriter

from model_Het import Policy
from utils_Het import compute_fusions, compute_normal_SND, compute_implicit_het, compute_dynamic_parameter_sharing, ops_ingredient
from wrappers import *

import seaborn as sns
import matplotlib.pyplot as plt


# Sacred experiment object. `ops_ingredient` is imported from utils_Het; config()
# and main() receive its values through the `ops` argument (presumably the
# ingredient's name -- confirm in utils_Het).
ex = Experiment(ingredients=[ops_ingredient])
# ex.captured_out_filter = apply_backspaces_and_linefeeds
# The filter ignores whatever was captured and returns a fixed placeholder string.
ex.captured_out_filter = lambda captured_output: "Output capturing turned off."
# Attach a file-storage observer rooted at the relative directory 'test/'.
ex.observers.append(FileStorageObserver('test/'))

# Root logger: INFO level, lines look like
# "(pid) [L] - (mm/dd HH:MM:SS) >> message" where L is the first letter of the level.
logging.basicConfig(
    level=logging.INFO,
    format="(%(process)d) [%(levelname).1s] - (%(asctime)s) >> %(message)s",
    datefmt="%m/%d %H:%M:%S",
)


import subprocess

def get_gpu_memory_map():
    """Return the memory currently in use on every GPU reported by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader``
    and parses one integer per non-blank output line.

    Returns:
        list[int]: Used memory per GPU, in the order ``nvidia-smi`` prints
        them (list index == nvidia-smi GPU index). Empty when the command
        prints nothing.

    Raises:
        OSError: ``nvidia-smi`` could not be executed (e.g. not installed).
        subprocess.CalledProcessError: ``nvidia-smi`` exited with a non-zero status.
        ValueError: A line of output is not an integer.
    """
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ], encoding='utf-8')

    # Skip blank lines so that empty output yields [] instead of int('') failing.
    return [int(line) for line in result.splitlines() if line.strip()]


def select_least_used_gpu():
    """Select the GPU with the least memory usage, falling back to the CPU.

    Returns:
        str: ``'cuda:<index>'`` of the GPU with the smallest used memory, or
        ``'cpu'`` when no GPU can be queried (``nvidia-smi`` missing, failing,
        or reporting no devices). The fallback matters because the experiment
        config evaluates this function unconditionally.
    """
    try:
        memory_map = get_gpu_memory_map()
    except (OSError, subprocess.CalledProcessError, ValueError) as err:
        logging.warning("Could not query GPUs through nvidia-smi (%s); using CPU.", err)
        return 'cpu'

    if not memory_map:
        logging.warning("nvidia-smi reported no GPUs; using CPU.")
        return 'cpu'

    # NOTE(review): the index is nvidia-smi's; if CUDA_VISIBLE_DEVICES remaps
    # devices, torch's 'cuda:N' numbering may differ -- confirm on the target host.
    return f'cuda:{memory_map.index(min(memory_map))}'

@ex.config
def config(ops):
    # Sacred config scope: every local assigned below becomes a config entry
    # that captured functions (@ex.capture / main) can receive by name and that
    # can be overridden from the command line (`with key=value`).
    # Keep the bare `def config(ops):` signature free of annotations.
    name = "MAPS"
    version = 0

    # Environment selection; _make_envs passes these to gym.make / TimeLimit.
    env_name = None
    time_limit = None  # falsy -> no TimeLimit wrapper is applied
    env_args = {}

    # Applied to each environment in this order by _make_envs.
    wrappers = (
        RecordEpisodeStatistics,
        SquashDones,
        SMACCompatible,
    )
    # True -> DummyVecEnv; False -> SubprocVecEnv (start_method="fork").
    dummy_vecenv = True

    # Base seed; environment i is seeded with seed + i.
    seed = 66666 # 10086 12086 13086

    # everything below is update steps (not env steps!)
    total_steps = int(10e6)
    log_interval = int(100)
    save_interval = int(1e6)
    eval_interval = int(1e4)

    # Passed to Policy(...) in main; presumably hidden-layer sizes -- confirm in model_Het.
    architecture = {
        "actor": [64, 64],
        "critic": [64, 64],
    }

    # Adam learning rate and epsilon.
    lr = 5e-4
    optim_eps = 0.00001

    parallel_envs = 32  # number of vectorised environments
    n_steps = 5         # rollout length kept in the storage deques
    gamma = 0.99
    entropy_coef = 0.01
    value_loss_coef = 0.5
    use_proper_termination = True
    central_v = False   # True -> value function is fed storage["state"] instead of "obs"

    # Algorithm selection. A mode containing the substring 'id' makes Torcherize
    # prepend one-hot agent ids; "hetdps" enables dynamic parameter sharing in main.
    algorithm_mode = "hetdps" 
    if_measure = True  # Measure Policy Distance
    # NOTE(review): the visible part of main passes a literal True to
    # compute_implicit_het instead of reading this entry -- confirm intent.
    continue_IHet_training = True



    # device = "cpu"
    # Evaluated when the config is built, i.e. nvidia-smi is queried on every run.
    device = select_least_used_gpu() # auto-gpu or cpu
    # device = "cuda:0" # auto-gpu or cpu

    # Used in the name of the per-run results folder created by main.
    experiment_label = "official"

    # Add PPO and GAE related parameters
    ppo_epochs = 4         # Number of PPO update epochs
    ppo_clip_param = 0.2   # PPO clipping parameter
    gae_lambda = 0.95      # GAE lambda parameter
    value_clip_param = 0.2 # Value function clipping parameter
    normalize_advantage = True  # Whether to normalize advantages
    max_grad_norm = 0.5    # Gradient clipping value





class Torcherize(VecEnvWrapper):
    """Vec-env wrapper that exchanges torch tensors with the training loop.

    Observations, rewards and done flags coming out of the wrapped vec-env are
    converted to float tensors on the configured device; actions going in are
    converted back to numpy. When the configured ``algorithm_mode`` contains
    the substring ``'id'``, a one-hot agent id is prepended to every agent's
    observation and the advertised observation space is widened to match.
    """

    @ex.capture
    def __init__(self, venv, algorithm_mode):
        """Wrap ``venv``; ``algorithm_mode`` is injected from the Sacred config."""
        super().__init__(venv)
        self.observe_agent_id = 'id' in algorithm_mode
        if self.observe_agent_id:
            agent_count = len(self.observation_space)
            # Each agent's Box grows by `agent_count` entries for the one-hot id.
            self.observation_space = gym.spaces.Tuple(tuple(
                gym.spaces.Box(
                    low=-np.inf,
                    high=np.inf,
                    shape=((x.shape[0] + agent_count),),
                    dtype=x.dtype,
                )
                for x in self.observation_space
            ))

    def _to_obs_tensors(self, obs, device, parallel_envs):
        """Convert per-agent numpy observations to float tensors on ``device``.

        Shared by ``reset`` and ``step_wait`` so both return observations of
        the same dtype and layout (``reset`` previously skipped ``.float()``).
        """
        obs = [torch.from_numpy(o).float().to(device) for o in obs]
        if self.observe_agent_id:
            n_agents = len(obs)
            # ids[i] has one row per environment, each row the one-hot code of agent i.
            ids = (
                torch.eye(n_agents)
                .repeat_interleave(parallel_envs, 0)
                .view(n_agents, -1, n_agents)
                .to(device)
            )
            obs = [torch.cat((ids[i], o), dim=1) for i, o in enumerate(obs)]
        return obs

    @ex.capture
    def reset(self, device, parallel_envs):
        """Reset the wrapped envs and return the per-agent observation tensors.

        ``device`` and ``parallel_envs`` are injected from the Sacred config.
        """
        return self._to_obs_tensors(self.venv.reset(), device, parallel_envs)

    def step_async(self, actions):
        """Forward per-agent action tensors to the wrapped envs.

        ``actions`` holds one tensor per agent; they are transposed into one
        tuple of agent actions per environment.
        """
        # NOTE(review): squeeze() without a dim drops every size-1 axis, so with a
        # single parallel env the env axis disappears as well -- confirm that
        # parallel_envs == 1 is never used.
        actions = [a.squeeze().cpu().numpy() for a in actions]
        actions = list(zip(*actions))
        return self.venv.step_async(actions)

    @ex.capture
    def step_wait(self, device, parallel_envs):
        """Collect the step result as ``(obs, rew, done, info)``.

        ``obs`` is a list of per-agent tensors; ``rew`` and ``done`` are float
        tensors on ``device``; ``info`` is passed through unchanged.
        """
        obs, rew, done, info = self.venv.step_wait()
        return (
            self._to_obs_tensors(obs, device, parallel_envs),
            torch.from_numpy(rew).float().to(device),
            torch.from_numpy(done).float().to(device),
            info,
        )


class SMACWrapper(VecEnvWrapper):
    """Vec-env wrapper that returns ``(obs, state, action_mask)`` as the observation.

    The state and the action mask are fetched from every wrapped environment
    through ``env_method("get_state")`` and ``env_method("get_avail_actions")``.
    """

    def _make_action_mask(self, n_agents):
        """Return one tensor per agent holding that agent's mask in every env."""
        per_env_masks = self.venv.env_method("get_avail_actions")
        masks = []
        for agent_idx in range(n_agents):
            agent_rows = [env_mask[agent_idx] for env_mask in per_env_masks]
            masks.append(torch.tensor(agent_rows))
        return masks

    def _make_state(self, n_agents):
        """Return the stacked env states, the same tensor repeated once per agent."""
        stacked = np.stack(self.venv.env_method("get_state"))
        shared_state = torch.from_numpy(stacked)
        return [shared_state] * n_agents

    def reset(self):
        """Reset the wrapped envs and return ``(obs, state, action_mask)``."""
        obs = self.venv.reset()
        n_agents = len(obs)
        state = self._make_state(n_agents)
        action_mask = self._make_action_mask(n_agents)
        return obs, state, action_mask

    def step_wait(self):
        """Return ``((obs, state, action_mask), rew, done, info)`` for the finished step."""
        obs, rew, done, info = self.venv.step_wait()
        n_agents = len(obs)
        state = self._make_state(n_agents)
        action_mask = self._make_action_mask(n_agents)
        augmented_obs = (obs, state, action_mask)
        return augmented_obs, rew, done, info


@ex.capture
def _compute_returns(storage, next_value, gamma):
    """Compute discounted returns bootstrapped from ``next_value``.

    Rewards and done flags are walked from the newest entry to the oldest,
    pairing the last reward with the last done flag.

    Args:
        storage: Mapping with ``"rewards"`` and ``"done"`` sequences of tensors.
        next_value: Value estimate used as the bootstrap (also the last list entry).
        gamma: Discount factor (injected from the Sacred config).

    Returns:
        list: Returns ordered oldest-first, with ``next_value`` as the final element.
    """
    running = next_value
    newest_first = [next_value]
    for rew, done in zip(reversed(storage["rewards"]), reversed(storage["done"])):
        running = running * gamma + rew * (1 - done.unsqueeze(1))
        newest_first.append(running)

    return newest_first[::-1]


@ex.capture
def _make_envs(env_name, env_args, parallel_envs, dummy_vecenv, wrappers, time_limit, seed):
    """Build the vectorised, torch-returning environment stack.

    Creates ``parallel_envs`` copies of ``env_name`` (seeded ``seed + i``),
    each optionally time-limited and passed through ``wrappers`` in order,
    vectorises them and finally wraps the result in Torcherize and SMACWrapper.
    All arguments can be injected from the Sacred config.
    """
    def _build_env(env_seed):
        # print(env_args)
        env = gym.make(env_name, **env_args)
        if time_limit:
            env = TimeLimit(env, time_limit)
        for wrap in wrappers:
            env = wrap(env)
        env.seed(env_seed)
        return env

    thunks = [partial(_build_env, seed + offset) for offset in range(parallel_envs)]
    if not dummy_vecenv:
        vec_env = SubprocVecEnv(thunks, start_method="fork")
    else:
        vec_env = DummyVecEnv(thunks)
        # Replace the reward buffer with one column per agent.
        reward_shape = (parallel_envs, len(vec_env.observation_space))
        vec_env.buf_rews = np.zeros(reward_shape, dtype=np.float32)
    return SMACWrapper(Torcherize(vec_env))


def _squash_info(info):
    info = [i for i in info if i]
    new_info = {}
    keys = set([k for i in info for k in i.keys()])
    keys.discard("TimeLimit.truncated")
    for key in keys:
        mean = np.mean([np.array(d[key]).sum() for d in info if key in d])
        new_info[key] = mean
    return new_info


@ex.capture
def _log_progress(
    infos,
    prev_time,
    step,
    parallel_envs,
    n_steps,
    total_steps,
    log_interval,
    folder_path,
    _log,
    _run,
):
    """Log throughput and episode statistics, then persist and plot them.

    Side effects, in order: writes five log lines through ``_log``; appends
    the current metrics to ``<folder_path>/training_metrics.json`` (the file is
    re-read and rewritten in full); redraws ``<folder_path>/training_curves.png``
    from the whole metrics history.

    Args:
        infos: Recent episode info dicts; each must contain ``"episode_reward"``
            and may contain ``"battle_won"``.
        prev_time: ``time.time()`` value taken at the previous log call.
        step: Current update step.
        folder_path: Directory receiving the JSON file and the plot.
        Remaining arguments are injected from the Sacred config / run.

    Returns:
        float: A fresh ``time.time()`` timestamp for the next call.
    """
    elapsed = time.time() - prev_time
    ups = log_interval / elapsed
    fps = ups * parallel_envs * n_steps
    n_episodes = len(infos)

    # Episode statistics: the per-episode rewards are added up, averaged over
    # episodes, then the resulting entries are summed into one number.
    reward_total = sum([ep["episode_reward"] for ep in infos])
    mean_reward = sum(reward_total / n_episodes)
    wins = sum([ep.get("battle_won", 0) for ep in infos])
    battles_won = 100 * wins / n_episodes

    # Console / log output
    env_timesteps = parallel_envs * n_steps * step
    percent_done = 100 * step / total_steps
    _log.info(f"Updates {step}, Environment timesteps {env_timesteps}")
    _log.info(f"UPS: {ups:.1f}, FPS: {fps:.1f}, ({percent_done:.2f}% completed)")
    _log.info(f"Last {n_episodes} episodes with mean reward: {mean_reward:.3f}")
    _log.info(f"Battles won: {battles_won:.1f}%")
    _log.info("-------------------------------------------")

    # Load the existing history (if any), add this record, write everything back.
    metrics_file = os.path.join(folder_path, "training_metrics.json")
    metrics_history = []
    if os.path.exists(metrics_file):
        with open(metrics_file, 'r') as f:
            metrics_history = json.load(f)

    metrics_history.append({
        "step": step,
        "mean_reward": mean_reward,
        "battles_won": battles_won,
        "fps": fps
    })
    with open(metrics_file, 'w') as f:
        json.dump(metrics_history, f, indent=4)

    # Series for the training-curve figure
    steps = [record['step'] for record in metrics_history]
    rewards = [record['mean_reward'] for record in metrics_history]
    win_rates = [record['battles_won'] for record in metrics_history]

    fig, reward_ax = plt.subplots(figsize=(10, 6))

    # Left Y-axis: mean reward
    reward_color = 'tab:blue'
    reward_ax.set_xlabel('Steps')
    reward_ax.set_ylabel('Mean Reward', color=reward_color)
    reward_ax.plot(steps, rewards, color=reward_color, label='Mean Reward')
    reward_ax.tick_params(axis='y', labelcolor=reward_color)

    # Right Y-axis: win rate
    win_ax = reward_ax.twinx()
    win_color = 'tab:orange'
    win_ax.set_ylabel('Win Rate (%)', color=win_color)
    win_ax.plot(steps, win_rates, color=win_color, label='Win Rate')
    win_ax.tick_params(axis='y', labelcolor=win_color)

    # One legend covering both axes
    reward_handles, reward_labels = reward_ax.get_legend_handles_labels()
    win_handles, win_labels = win_ax.get_legend_handles_labels()
    reward_ax.legend(reward_handles + win_handles, reward_labels + win_labels, loc='upper left')

    plt.title('Training Progress')
    plt.tight_layout()

    # Write the figure to disk and release it
    plt.savefig(os.path.join(folder_path, 'training_curves.png'))
    plt.close()

    return time.time()  # new reference timestamp for the next interval


@ex.capture
def _compute_gae(storage, next_value, gamma, gae_lambda):
    """Compute GAE-based return targets for the stored rollout.

    Args:
        storage: Mapping with ``"values"``, ``"rewards"`` and ``"done"`` sequences.
        next_value: Bootstrap value; appended after the stored values.
        gamma: Discount factor (injected from the Sacred config).
        gae_lambda: GAE lambda (injected from the Sacred config).

    Returns:
        list: One entry per stored reward, oldest first, each equal to the
        GAE advantage plus the stored value at that step.
    """
    # Everything is moved onto the device of the bootstrap value.
    device = next_value.device

    # Stored values followed by the bootstrap value.
    values = [*storage["values"], next_value]

    gae = 0
    newest_first = []
    # NOTE(review): done is read at the same index as the reward, whereas
    # _compute_returns aligns rewards and done flags from the end of the deques.
    # If storage["done"] holds one more entry than storage["rewards"], the flag
    # used here is the one preceding the step -- confirm against the rollout loop.
    for t in range(len(storage["rewards"]) - 1, -1, -1):
        reward = storage["rewards"][t].to(device)
        not_done = 1 - storage["done"][t].to(device).unsqueeze(1)
        delta = reward + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        newest_first.append(gae + values[t])

    newest_first.reverse()
    return newest_first

@ex.capture
def _compute_loss(model, optimizer, storage, value_loss_coef, entropy_coef, central_v, 
                 ppo_clip_param, normalize_advantage, value_clip_param, ppo_epochs, max_grad_norm):
    """Run ``ppo_epochs`` clipped-PPO optimisation steps on the stored rollout.

    Despite the name, this function does not only compute a loss: in every
    epoch it zeroes the gradients, back-propagates, clips the gradient norm to
    ``max_grad_norm`` and calls ``optimizer.step()``.

    Args:
        model: Policy exposing ``get_value`` and ``evaluate_actions``.
        optimizer: Optimizer stepped once per epoch.
        storage: Rollout storage; reads "obs", "action_mask", "actions" and,
            when ``central_v`` is true, "state" (plus whatever _compute_gae reads).
        Remaining arguments are injected from the Sacred config.

    Returns:
        The sum of the per-epoch loss tensors divided by ``ppo_epochs``.
    """
    # Get next state value estimate

    

    # Bootstrap value from the newest stored state/observation; no gradient needed.
    with torch.no_grad():
        next_value = model.get_value(storage["state" if central_v else "obs"][-1])
    
    # Calculate GAE and returns
    returns = _compute_gae(storage, next_value)
    returns = torch.stack(returns)  # [steps, batch_size, n_agents]
    
    # Prepare input data
    # zip(*...) regroups the per-step lists into per-agent sequences; [:-1] drops
    # the newest entry, which was only used for the bootstrap value above.
    obs = [torch.stack(list(zip(*storage["obs"]))[i])[:-1] for i in range(len(storage["obs"][0]))]
    if central_v:
        state = [torch.stack(list(zip(*storage["state"]))[i])[:-1] for i in range(len(storage["state"][0]))]
    else:
        state = None
    action_mask = [torch.stack(list(zip(*storage["action_mask"]))[i])[:-1] for i in range(len(storage["action_mask"][0]))]
    # Actions are used in full (no trailing entry is dropped).
    actions = [torch.stack(list(zip(*storage["actions"]))[i]) for i in range(len(storage["actions"][0]))]
    
    # Get old policy action probabilities and values
    # Evaluated once, before any update, and held fixed across all epochs.
    with torch.no_grad():
        old_values, old_log_probs, _, _ = model.evaluate_actions(
                obs, actions, action_mask, state
            )
    
    total_loss = 0
    for _ in range(ppo_epochs):
        # Evaluate current policy
        optimizer.zero_grad()
        
        values, action_log_probs, entropy, _ = model.evaluate_actions(
                obs, actions, action_mask, state
            )
        
        # Calculate advantage
        # Recomputed every epoch from the current (detached) values, not from old_values.
        advantage = returns - values.detach()
        if normalize_advantage:
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
        
        # Calculate policy ratio
        ratio = torch.exp(action_log_probs - old_log_probs)
        
        # PPO clipping target
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1.0 - ppo_clip_param, 1.0 + ppo_clip_param) * advantage
        policy_loss = -torch.min(surr1, surr2).mean()
        
        # Value function clipping loss
        # The new value may move at most value_clip_param away from the old one;
        # the larger of the clipped and unclipped squared errors is used.
        value_pred_clipped = old_values + (values - old_values).clamp(-value_clip_param, value_clip_param)
        value_loss1 = (values - returns).pow(2)
        value_loss2 = (value_pred_clipped - returns).pow(2)
        value_loss = 0.5 * torch.max(value_loss1, value_loss2).mean()

        loss = policy_loss + value_loss_coef * value_loss - entropy_coef * entropy
        # The loss tensor itself (not a detached copy) is accumulated.
        total_loss += loss

        loss.backward()
        # Use max_grad_norm parameter from configuration
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
    
    return total_loss / ppo_epochs


def visualize_distance_matrices(matrices_data, step, folder_path):
    """
    Plot multiple distance matrices as heatmaps and save

    All heatmaps share one colour scale spanning the global minimum and maximum
    over every matrix. The figure is written to
    ``<folder_path>/distance_matrices_<step>.png``. Nothing is drawn or written
    when ``matrices_data`` is empty.

    Args:
        matrices_data (dict): Dictionary containing multiple distance matrices
            (name -> 2-D nested list or array)
        step (int): Current step
        folder_path (str): Save path
    """
    n_matrices = len(matrices_data)
    if n_matrices == 0:
        # A zero-row subplot grid and the min/max of an empty sequence both
        # raise, so there is nothing sensible to plot.
        return

    arrays = {name: np.asarray(matrix) for name, matrix in matrices_data.items()}

    # Global value range so that colours are comparable across heatmaps
    all_values = np.concatenate([arr.ravel() for arr in arrays.values()])
    vmin, vmax = all_values.min(), all_values.max()

    # Two heatmaps per row; squeeze=False always yields a 2-D array of axes
    n_cols = 2
    n_rows = (n_matrices + 1) // 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5*n_rows), squeeze=False)
    axes = axes.flatten()

    try:
        # Plot each matrix heatmap
        for ax, (name, matrix) in zip(axes, arrays.items()):
            sns.heatmap(
                matrix,
                ax=ax,
                cmap='YlOrRd',
                vmin=vmin,
                vmax=vmax,
                annot=True,  # Show specific values
                fmt='.2f',   # Format values as 2 decimal places
                square=True, # Keep square
                cbar_kws={'label': 'Distance'}
            )

            # Set title and labels
            ax.set_title(f'{name} Matrix')
            ax.set_xlabel('Agent Index')
            ax.set_ylabel('Agent Index')

        # Remove the unused subplot left over when the matrix count is odd
        for spare_ax in axes[n_matrices:]:
            fig.delaxes(spare_ax)

        # Adjust layout and save
        fig.tight_layout()
        fig.savefig(f"{folder_path}/distance_matrices_{step}.png", dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)


def backup_specific_code_files(target_dir):
    """
    Backup specified Python code files to target directory's code_backup subfolder

    The files are looked up next to this script. A file that does not exist is
    reported with a warning and skipped; the function never fails because of it.

    Args:
        target_dir (str): Target directory path
    """
    # Create code backup folder; exist_ok avoids the check-then-create race
    # and makes repeated calls safe.
    backup_dir = os.path.join(target_dir, "code_backup")
    os.makedirs(backup_dir, exist_ok=True)

    # Get current script directory
    current_dir = os.path.dirname(os.path.abspath(__file__))

    files_to_backup = (
        "main_Het.py",
        "utils_Het.py",
        "model_Het.py",
    )

    # Backup specified files
    for file_name in files_to_backup:
        src_file = os.path.join(current_dir, file_name)
        dst_file = os.path.join(backup_dir, file_name)
        if os.path.exists(src_file):
            # copy2 also preserves file metadata such as timestamps
            shutil.copy2(src_file, dst_file)
            print(f"Backed up file: {file_name}")
        else:
            print(f"Warning: File {file_name} does not exist, cannot backup")



def visualize_hetdps_results(distance_matrix, new_clustering, network_assignments, step, folder_path):
    """
    Visualize HetDPS clustering and network assignment results

    Writes two files into ``folder_path``: ``hetdps_results_<step>.png`` (distance
    heatmap, clustering strip and network-assignment strip) and
    ``hetdps_results_<step>.txt`` (the same clustering / assignments as text).

    Args:
        distance_matrix (torch.Tensor): Distance matrix (a non-tensor 2-D array
            is used as is)
        new_clustering (numpy.ndarray or dict): Clustering results; a dict maps
            agent id -> cluster id, a sequence is indexed by agent id
        network_assignments (list or numpy.ndarray): Network assignment results;
            any other indexable is read for agents ``0 .. len(distance_matrix)-1``
        step (int): Current training step
        folder_path (str): Results save path
    """
    # Convert tensor to numpy array for plotting (detach in case it carries grad)
    if isinstance(distance_matrix, torch.Tensor):
        distance_matrix = distance_matrix.detach().cpu().numpy()
    n_agents = len(distance_matrix)

    # Group agent ids by cluster id, whichever form the clustering arrives in
    if isinstance(new_clustering, dict):
        agent_cluster_pairs = new_clustering.items()
    else:
        agent_cluster_pairs = enumerate(new_clustering)
    clusters = {}
    for agent_id, cluster_id in agent_cluster_pairs:
        clusters.setdefault(cluster_id, []).append(agent_id)

    # Ensure network_assignments is list or array
    if not isinstance(network_assignments, (list, np.ndarray)):
        network_assignments = [network_assignments[i] for i in range(n_agents)]

    fig = plt.figure(figsize=(15, 10))
    try:
        # Left: Distance matrix heatmap (spans both rows)
        ax1 = plt.subplot2grid((2, 2), (0, 0), rowspan=2, fig=fig)
        sns.heatmap(
            distance_matrix,
            ax=ax1,
            cmap='YlOrRd',
            annot=True,
            fmt='.2f',
            square=True,
            cbar_kws={'label': 'Distance'}
        )
        ax1.set_title('Distance Matrix')
        ax1.set_xlabel('Agent Index')
        ax1.set_ylabel('Agent Index')

        # Right top: Clustering results, one colour per cluster
        ax2 = plt.subplot2grid((2, 2), (0, 1), fig=fig)
        colors = plt.cm.nipy_spectral(np.linspace(0, 1, len(clusters)))
        for i, (cluster_id, agents) in enumerate(sorted(clusters.items())):
            for agent in agents:
                ax2.scatter(agent, 0, c=[colors[i]], s=100)
            # Label the cluster above the mean position of its members
            ax2.text(np.mean(agents), 0.1, f'Cluster {cluster_id}',
                     ha='center', va='bottom', fontsize=9, color=colors[i])

        ax2.set_title('Agent Clustering Results')
        ax2.set_xlabel('Agent ID')
        ax2.set_yticks([])
        ax2.set_xlim(-0.5, n_agents-0.5)

        # Right bottom: Network assignment, one colour per distinct network
        ax3 = plt.subplot2grid((2, 2), (1, 1), fig=fig)
        unique_assignments = sorted(set(network_assignments))
        colors_assignments = plt.cm.tab10(np.linspace(0, 1, len(unique_assignments)))
        for agent_id, network_id in enumerate(network_assignments):
            color_idx = unique_assignments.index(network_id)
            ax3.scatter(agent_id, 0, c=[colors_assignments[color_idx]], s=100)
            ax3.text(agent_id, 0.1, f'Network {network_id}',
                     ha='center', va='bottom', fontsize=9)

        ax3.set_title('Network Assignment Results')
        ax3.set_xlabel('Agent ID')
        ax3.set_yticks([])
        ax3.set_xlim(-0.5, len(network_assignments)-0.5)

        fig.tight_layout()
        fig.savefig(os.path.join(folder_path, f"hetdps_results_{step}.png"), dpi=300, bbox_inches='tight')
    finally:
        # Release this figure even when plotting or saving raised.
        plt.close(fig)

    # Save clustering and network assignment results to text file
    with open(os.path.join(folder_path, f"hetdps_results_{step}.txt"), "w") as file:
        file.write(f"Step {step} - HetDPS Results\n\n")
        file.write("Clustering Results:\n")
        for cluster_id, agents in sorted(clusters.items()):
            file.write(f"Cluster {cluster_id}: {sorted(agents)}\n")
        file.write("\nNetwork Assignments:\n")
        for agent_id, network_id in enumerate(network_assignments):
            file.write(f"Agent {agent_id}: Network {network_id}\n")





@ex.automain
def main(
    _run,
    seed,
    total_steps,
    log_interval,
    save_interval,
    eval_interval,
    architecture,
    lr,
    optim_eps,
    parallel_envs,
    n_steps,
    use_proper_termination,
    central_v,
    max_grad_norm,
    ops,
    algorithm_mode,
    env_name,
    if_measure,
    device,
    _log,
    experiment_label,
):
    
    torch.set_num_threads(1)

    envs = _make_envs(seed=seed)


    # Modify file folder creation logic
    current_directory = os.path.dirname(os.path.abspath(__file__))
    
    # 1. Create main folder NeurIPS_MPE_X
    main_folder_name = f"NeurIPS_MPE_{env_name}"
    main_folder_path = os.path.join(current_directory, main_folder_name)
    if not os.path.exists(main_folder_path):
        os.makedirs(main_folder_path)
    
    # 2. Create subfolder algorithm_mode_experiment_label_date_time
    from datetime import datetime
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    sub_folder_name = f"{algorithm_mode}_{experiment_label}_{current_time}"
    folder_path = os.path.join(main_folder_path, sub_folder_name)
    os.makedirs(folder_path)
    print(f"Created folder: {folder_path}")
    
    # Backup specified Python code files
    print(f"Backing up code files to: {folder_path}")
    backup_specific_code_files(folder_path)



    agent_count = len(envs.observation_space)
    obs_size = envs.observation_space[0].shape
    obs_size = obs_size[0]
    origin_obs_size = obs_size - agent_count if 'id' in algorithm_mode else obs_size
    act_size = envs.action_space[0].n

    # temp version 
    local_state_size = origin_obs_size
    policy_input_size = obs_size + ops['z_features_IHet'] if 'IHet' in algorithm_mode else obs_size

    env_dict = {
        "obs": {"shape": policy_input_size, "dtype": np.float32},
        "origin_obs": {"shape": origin_obs_size, "dtype": np.float32},
        "local_state": {"shape": local_state_size, "dtype": np.float32},
        # "hidden": {"shape": obs_size, "dtype": np.float32},
        "rew": {"shape": 1, "dtype": np.float32},
        "rew_pad": {"shape": origin_obs_size, "dtype": np.float32},
        "next_obs": {"shape": origin_obs_size, "dtype": np.float32},
        "next_local_state": {"shape": origin_obs_size, "dtype": np.float32},
        "done": {"shape": 1, "dtype": np.float32},
        "act": {"shape": act_size, "dtype": np.float32},
        "act_probs": {"shape": act_size, "dtype": np.float32},
        "act_probs_with_mask": {"shape": act_size, "dtype": np.float32},
        "act_mask": {"shape": act_size, "dtype": np.float32},
        "agent": {"shape": agent_count, "dtype": np.float32},
    }
    rb = ReplayBuffer(int(agent_count * ops['max_rb_steps'] * parallel_envs * n_steps), env_dict) 

    state_size = envs.get_attr("state_size")[0] if central_v else None

    assert ops["model_count"]
    model_count = ops["model_count"]
    


    model = Policy(envs.observation_space, envs.action_space, architecture, model_count, state_size, initial_as_the_same=ops['initial_as_the_same'])
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr, eps=optim_eps)

    # creates and initialises storage
    obs, state, action_mask = envs.reset()

    
    last_origin_obs = [o for o in obs]

    storage = defaultdict(lambda: deque(maxlen=n_steps))
    storage["obs"] = deque(maxlen=n_steps + 1)
    storage["done"] = deque(maxlen=n_steps + 1)
    storage["values"] = deque(maxlen=n_steps + 1)  
    storage["old_log_probs"] = deque(maxlen=n_steps + 1)  
    storage["obs"].append(obs)
    storage["done"].append(torch.zeros(parallel_envs))
    storage["info"] = deque(maxlen=10)

    # for SMAC:
    storage["state"] = deque(maxlen=n_steps + 1)
    storage["action_mask"] = deque(maxlen=n_steps + 1)
    if central_v:
        storage["state"].append(state)
    storage["action_mask"].append(action_mask)
    # ---------

    # for IHet
    storage["local_state"] = deque(maxlen=n_steps + 1)
    storage["local_state"].append(last_origin_obs)
    storage["local_act"] = deque(maxlen=n_steps + 1)
    zero_act = torch.zeros(agent_count, parallel_envs, act_size).to(device)
    storage["local_act"].append(zero_act)

    model.sample_laac(parallel_envs)

    model.laac_shallow = torch.zeros(parallel_envs, agent_count).long()
    model.laac_deep = torch.zeros(parallel_envs, agent_count).long()


    

    print(model.laac_shallow)
    print(model.laac_deep)

    start_time = time.time()

    avg_BD_storage = []
    avg_Hellinger_storage = []
    avg_WD_storage = []
    avg_BD_vanilla_storage = []
    avg_Hellinger_vanilla_storage = []


    IHet_VAE_model = None

    if algorithm_mode == "hetdps":
        (previous_clustering, previous_network_assignments) = None, None

    for step in range(total_steps):

                
        if (algorithm_mode in ["hetdps"] or if_measure) and step in [ops["reparameter_interval"]*(i+1) for i in range(ops["reparameter_times"])]:
            
            # -----------------------------Model-Free Heterogeneity Measuring-----------------------------
            print(f"Measuring Implicit Heterogeneity at step: {step}")        

            compute_IHet = True # Ready for only training without measuring mode 
            IHet_BD, IHet_Hellinger, IHet_WD, IHet_VAE_model = compute_implicit_het(rb.get_all_transitions(), agent_count, model,
                                                                                    continue_IHet_training=True,
                                                                                    pretrained_VAE=IHet_VAE_model,
                                                                                    compute_IHet=compute_IHet)
            
            if algorithm_mode == "hetdps":
                print(f"Running HetDPS at step: {step}")
                # -----------------------------HetDPS-----------------------------
                new_clustering, network_assignments = compute_dynamic_parameter_sharing(IHet_WD, previous_clustering=previous_clustering, previous_network_assignments=previous_network_assignments, policy_model=model, merge_mode='majority', seed=seed)
                print(f"New clustering: {new_clustering}")
                (previous_clustering, previous_network_assignments) = new_clustering, network_assignments
                device = IHet_WD.device
                E = model.laac_shallow.shape[0]  
                network_tensor = torch.tensor(network_assignments, dtype=torch.long).to(device)
                model.laac_shallow = network_tensor.repeat(E, 1)  # Copy same allocation to each environment
                model.laac_deep = network_tensor.repeat(E, 1)     # Deep network uses same allocation
                visualize_hetdps_results(IHet_WD, new_clustering, network_assignments, step, folder_path)
                
                # Save clustering and network assignment results to extra text file for easy viewing
                with open(f"{folder_path}/hetdps_clustering.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(f"Clustering: {new_clustering}\n")
                    file.write(f"Network assignments: {network_assignments}\n\n")



            # -----------------------------Model-Based Heterogeneity Measuring-----------------------------
            only_measure = True
            policy_mode = 'Pure'
            policy_submodel = None

            BD, Hellinger, WD, laac_s, laac_d = compute_fusions(rb.get_all_transitions(), agent_count, model, measure_mode=only_measure, policy_mode=policy_mode, policy_submodel=policy_submodel)
            vanilla_BD, vanilla_Hellinger = compute_normal_SND(rb.get_all_transitions(), agent_count, model, policy_mode=policy_mode, policy_submodel=policy_submodel) 

            if IHet_BD is not None:
                matrices_data = {
                    'WD': WD.cpu().numpy().tolist(),
                    'IHet_WD': IHet_WD.cpu().numpy().tolist(),
                    'BD': BD.cpu().numpy().tolist(),
                    'IHet_BD': IHet_BD.cpu().numpy().tolist(),
                    # 'Hellinger': Hellinger.cpu().numpy().tolist(),
                    # 'IHet_Hellinger': IHet_Hellinger.cpu().numpy().tolist(),
                    'vanilla_BD': vanilla_BD.cpu().numpy().tolist(),
                    # 'vanilla_Hellinger': vanilla_Hellinger.cpu().numpy().tolist()
                }
                visualize_distance_matrices(matrices_data, step, folder_path)



                # Store raw JSON file
                (BD, Hellinger, WD, vanilla_BD, vanilla_Hellinger) = (BD.cpu().numpy(),Hellinger.cpu().numpy(),WD.cpu().numpy(),
                                                                    vanilla_BD.cpu().numpy(),vanilla_Hellinger.cpu().numpy())
                (IHet_BD, IHet_Hellinger, IHet_WD) = (IHet_BD.cpu().numpy(), IHet_Hellinger.cpu().numpy(), IHet_WD.cpu().numpy())

                avg_BD = BD.mean().item()
                avg_Hellinger = Hellinger.mean().item()
                avg_WD = WD.mean().item()
                avg_BD_vanilla = vanilla_BD.mean().item()
                avg_Hellinger_vanilla = vanilla_Hellinger.mean().item()
                
                avg_BD_storage.append(avg_BD)
                avg_Hellinger_storage.append(avg_Hellinger)
                avg_WD_storage.append(avg_WD)
                avg_BD_vanilla_storage.append(avg_BD_vanilla)
                avg_Hellinger_vanilla_storage.append(avg_Hellinger_vanilla)
                

                with open(os.path.join(folder_path, "avg_BD.json"), "w") as file:
                    json.dump(avg_BD_storage, file)
                with open(os.path.join(folder_path, "avg_Hellinger.json"), "w") as file:
                    json.dump(avg_Hellinger_storage, file)
                with open(os.path.join(folder_path, "avg_WD.json"), "w") as file:
                    json.dump(avg_WD_storage, file)
                with open(os.path.join(folder_path, "avg_BD_vanilla.json"), "w") as file:
                    json.dump(avg_BD_vanilla_storage, file)
                with open(os.path.join(folder_path, "avg_Hellinger_vanilla.json"), "w") as file:
                    json.dump(avg_Hellinger_vanilla_storage, file)


                
                (BD_r, Hellinger_r, WD_r, vanilla_BD_r, vanilla_Hellinger_r, IHet_BD_r, IHet_Hellinger_r, IHet_WD_r) = (
                    np.around(BD, 4), np.around(Hellinger, 4), np.around(WD, 4),
                    np.around(vanilla_BD, 4), np.around(vanilla_Hellinger, 4),
                    np.around(IHet_BD, 4), np.around(IHet_Hellinger, 4), np.around(IHet_WD, 4)
                )


                with open(f"{folder_path}/BD.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(BD_r))
                    file.write("\n\n")  

                with open(f"{folder_path}/Hellinger.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(Hellinger_r))
                    file.write("\n\n")  

                with open(f"{folder_path}/WD.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(WD_r))
                    file.write("\n\n")  

                with open(f"{folder_path}/vanilla_BD.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(vanilla_BD_r))
                    file.write("\n\n")  

                with open(f"{folder_path}/vanilla_Hellinger.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(vanilla_Hellinger_r))
                    file.write("\n\n")  

                with open(f"{folder_path}/IHet_BD.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(IHet_BD_r))
                    file.write("\n\n")  

                with open(f"{folder_path}/IHet_Hellinger.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(IHet_Hellinger_r))
                    file.write("\n\n")  

                with open(f"{folder_path}/IHet_WD.txt", "a") as file:
                    file.write(f"Step {step}:\n")
                    file.write(str(IHet_WD_r))
                    file.write("\n\n")  
            

            

            # pickle.dump(rb.get_all_transitions(), open(f"{env_name}.p", "wb"))
            _log.info(model.laac_shallow)
            _log.info(model.laac_deep)


        if step % log_interval == 0 and len(storage["info"]):
            start_time = _log_progress(
                storage["info"], 
                start_time, 
                step,
                parallel_envs,
                n_steps,
                total_steps,
                log_interval,
                folder_path,
                _log
            )
            storage["info"].clear()

        # ---------------------    Sampling from Environment    ---------------------
        for n_step in range(n_steps):
            # ---------------------    Get Actions ---------------------
            with torch.no_grad():
                actions, act_probs_with_mask = model.act(storage["obs"][-1], storage["action_mask"][-1])
                act_probs = model.get_act_probs(storage["obs"][-1])           
                values = model.get_value(storage["state" if central_v else "obs"][-1])
                storage["values"].append(values)
                storage["local_state"].append(last_origin_obs)


            # ---------------------    Step Environment ---------------------
            (obs, state, action_mask), reward, done, info = envs.step(actions)
            last_origin_obs = [o.clone() for o in obs]


            if use_proper_termination:
                bad_done = torch.FloatTensor(
                    [1.0 if i.get("TimeLimit.truncated", False) else 0.0 for i in info]
                ).to(device)
                done = done - bad_done


            # ---------------------    Storage ---------------------
            one_hot_action = []
            for agent_actions in actions:
                one_hot = torch.nn.functional.one_hot(agent_actions.squeeze(-1), act_size).float()
                one_hot_action.append(one_hot)
            one_hot_action_tensor = torch.stack(one_hot_action, dim=0)
            storage["local_act"].append(one_hot_action_tensor)
            storage["obs"].append(obs)
            storage["actions"].append(actions)
            storage["rewards"].append(reward)
            storage["done"].append(done)
            storage["info"].extend([i for i in info if "episode_reward" in i])
            storage["laac_rewards"] += reward

            if (algorithm_mode in ["hetdps"] or if_measure) and step < (ops["delay_steps"] + ops["reparameter_times"] * ops["reparameter_interval"]):
                for agent in range(len(obs)):

                    one_hot_action = torch.nn.functional.one_hot(actions[agent], act_size).squeeze().cpu().numpy()
                    one_hot_agent = torch.nn.functional.one_hot(torch.tensor(agent), agent_count).repeat(parallel_envs, 1).numpy()

                    if bad_done[0]:
                        nobs = info[0]["terminal_observation"]
                        nobs = [torch.tensor(o).unsqueeze(0) for o in nobs]
                    else:
                        nobs = obs
                    
                    
                    origin_obs = storage["obs"][-2]

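                    # Key trick: pad the reward by tiling this agent's per-env reward
                    # across the observation width, so rew_pad has as many columns as
                    # the agent's observation.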
                    origin_obs_dim = origin_obs[agent].size(-1)
                    rew_pad = reward[:, agent].unsqueeze(-1).repeat(1, origin_obs_dim).cpu().numpy()

                    # Build this agent's transition record; it is added to the replay
                    # buffer `rb`, the sample pool used for measurement during training.
                    data = {
                        "obs": storage["obs"][-2][agent].cpu().numpy(),
                        "origin_obs": origin_obs[agent].cpu().numpy(),
                        "local_state": origin_obs[agent].cpu().numpy(),
                        # "hidden": storage["obs"][-2][agent].cpu().numpy(), 
                        "act_mask": action_mask[agent].cpu().numpy(),  
                        "act": one_hot_action,
                        "act_probs": act_probs[agent].cpu().numpy(),  
                        "act_probs_with_mask": act_probs_with_mask[agent].cpu().numpy(),  
                        "next_obs": nobs[agent].cpu().numpy(),
                        "next_local_state": nobs[agent].cpu().numpy(),
                        "rew":  reward[:, agent].unsqueeze(-1).cpu().numpy(),
                        "rew_pad": rew_pad,
                        "done": done[:].unsqueeze(-1).cpu().numpy(),
                        # "policy": np.array([model.laac_sample[0, agent].float().item()]),
                        "agent": one_hot_agent,
                        # "timestep": step,
                        # "nstep": n_step,
                    }
                    try:
                        rb.add(**data)
                    except Exception as e:
                        print(e)
                        assert True

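            # For SMAC: when central_v is set, also keep the state returned by
            # envs.step, since get_value reads storage["state"] in that case.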
            if central_v:
                storage["state"].append(state)

            storage["action_mask"].append(action_mask)
            # ---------

        if algorithm_mode == "seps" and step < ops["delay_steps"] and ops["delay_training_seps"]:
            continue


        loss = _compute_loss(model, optimizer, storage)



    envs.close()

